{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp text.core\n",
    "#default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text core\n",
    "\n",
    "> Basic functions to preprocess text before assembling it in a `DataLoaders`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "import spacy,html\n",
    "from spacy.symbols import ORTH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing rules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following are rules applied to texts before or after they are tokenized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Special tokens: unknown word, padding, beginning/end of text, field marker,\n",
    "# repeated character, repeated word, all-caps word and capitalized word\n",
    "UNK,PAD,BOS,EOS,FLD = 'xxunk','xxpad','xxbos','xxeos','xxfld'\n",
    "TK_REP,TK_WREP,TK_UP,TK_MAJ = 'xxrep','xxwrep','xxup','xxmaj'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# nbdev `_all_` convention: extra names to add to the exported module's `__all__`\n",
    "# (presumably needed because the tuple assignment of the special tokens is not detected\n",
    "# automatically) -- keep it a literal list so the export tooling can read it\n",
    "_all_ = [\"UNK\", \"PAD\", \"BOS\", \"EOS\", \"FLD\", \"TK_REP\", \"TK_WREP\", \"TK_UP\", \"TK_MAJ\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_re_spec = re.compile(r'([/#\\\\])')\n",
    "\n",
    "def spec_add_spaces(t):\n",
    "    \"Add spaces around `/`, `#` and backslash characters in `t`\"\n",
    "    return _re_spec.sub(r' \\1 ', t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for inp,exp in [('#fastai', ' # fastai'), ('/fastai', ' / fastai'), ('\\\\fastai', ' \\\\ fastai')]:\n",
    "    test_eq(spec_add_spaces(inp), exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_re_space = re.compile(' {2,}')\n",
    "\n",
    "def rm_useless_spaces(t):\n",
    "    \"Collapse every run of two or more spaces in `t` into a single space\"\n",
    "    # only the space character is affected; tabs and newlines are left as they are\n",
    "    collapsed = _re_space.sub(' ', t)\n",
    "    return collapsed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = 'a  b   c'\n",
    "test_eq(rm_useless_spaces(inp), 'a b c')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "_re_rep = re.compile(r'(\\S)(\\1{2,})')\n",
    "\n",
    "def replace_rep(t):\n",
    "    \"Replace runs of 3+ identical non-space characters: cccc -- TK_REP 4 c\"\n",
    "    def _replace_rep(m):\n",
    "        char,rest = m.group(1),m.group(2)\n",
    "        # `rest` holds the repeats after the first occurrence\n",
    "        return f' {TK_REP} {len(rest)+1} {char} '\n",
    "    return _re_rep.sub(_replace_rep, t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It starts replacing at 3 repetitions of the same character or more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fewer than 3 repetitions are left untouched\n",
    "test_eq(replace_rep('aa'), 'aa')\n",
    "test_eq(replace_rep('a'*4), f' {TK_REP} 4 a ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# A word at the start of text or after whitespace, followed by 2+ more whitespace-separated copies\n",
    "# of itself (3+ occurrences in total); group 3 keeps the character that ends the run\n",
    "_re_wrep = re.compile(r'(?:\\s|^)(\\w+)\\s+((?:\\1\\s+)+)\\1(\\s|\\W|$)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "r\"\"\"\n",
    "Matches any word repeated at least three times with whitespace between them\n",
    "(?:\\s|^)       Match (without capturing) either a whitespace character or the beginning of text\n",
    "(\\w+)          Capture a word (one or more word characters)\n",
    "\\s+            One or more whitespace characters\n",
    "((?:\\1\\s+)+)   Capture one or more further copies of \\1, each followed by one or more whitespace\n",
    "\\1             A final occurrence of \\1\n",
    "(\\s|\\W|$)      Capture the trailing whitespace, non-word character or end of text\n",
    "\"\"\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def replace_wrep(t):\n",
    "    \"Replace runs of 3+ identical words: word word word word -- TK_WREP 4 word\"\n",
    "    def _replace_wrep(m):\n",
    "        word,repeats,end = m.groups()\n",
    "        # `repeats` excludes the first and the last occurrence, hence the +2\n",
    "        n_total = len(repeats.split()) + 2\n",
    "        return f' {TK_WREP} {n_total} {word} {end}'\n",
    "    return _re_wrep.sub(_replace_wrep, t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It starts replacing at 3 repetitions of the same word or more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for inp,exp in [('ah ah', 'ah ah'),\n",
    "                ('ah ah ah', f' {TK_WREP} 3 ah '),\n",
    "                ('ah ah   ah  ah', f' {TK_WREP} 4 ah '),\n",
    "                ('ah ah ah ah ', f' {TK_WREP} 4 ah  '),\n",
    "                ('ah ah ah ah.', f' {TK_WREP} 4 ah .'),\n",
    "                ('ah ah ahi', 'ah ah ahi')]:\n",
    "    test_eq(replace_wrep(inp), exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def fix_html(x):\n",
    "    \"Replace HTML-entity fragments and other artefacts seen in documents, then unescape the rest\"\n",
    "    # substitutions are applied one after another, in list order\n",
    "    subs = [('#39;', \"'\"), ('amp;', '&'), ('#146;', \"'\"), ('nbsp;', ' '), ('#36;', '$'),\n",
    "            ('\\\\n', \"\\n\"), ('quot;', \"'\"), ('<br />', \"\\n\"), ('\\\\\"', '\"'), ('<unk>', UNK),\n",
    "            (' @.@ ', '.'), (' @-@ ', '-'), ('...', ' …')]\n",
    "    for old,new in subs: x = x.replace(old, new)\n",
    "    return html.unescape(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cases = {'#39;bli#146;': \"'bli'\",\n",
    "         'Sarah amp; Duck...': 'Sarah & Duck …',\n",
    "         'a nbsp; #36;': 'a   $',\n",
    "         '\\\\\" <unk>': f'\" {UNK}',\n",
    "         'quot;  @.@  @-@ ': \"' .-\",\n",
    "         '<br />text\\\\n': '\\ntext\\n'}\n",
    "for inp,exp in cases.items(): test_eq(fix_html(inp), exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# A whitespace/start-of-text prefix (group 1), then a word of capital letters possibly followed by\n",
    "# other non-lowercase characters (group 2), ending right before whitespace or the end of text\n",
    "_re_all_caps = re.compile(r'(\\s|^)([A-Z]+[^a-z\\s]*)(?=(\\s|$))')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "r\"\"\"\n",
    "Catches any word in all caps, even with ' or - inside\n",
    "(\\s|^)        Capture either a whitespace or the beginning of text\n",
    "([A-Z]+       Capture one or more capital letters...\n",
    "[^a-z\\s]*)    ...followed by any characters that are neither lowercase letters nor whitespace\n",
    "(?=(\\s|$))    Look ahead for a whitespace or end of text\n",
    "\"\"\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def replace_all_caps(t):\n",
    "    \"Lowercase words written in ALL CAPS, inserting `TK_UP` before those longer than one character\"\n",
    "    def _replace_all_caps(m):\n",
    "        lead,word = m.group(1),m.group(2)\n",
    "        # a single capital letter (e.g. \"I\") is lowered without a marker\n",
    "        marker = f'{TK_UP} ' if len(word) > 1 else ''\n",
    "        return f\"{lead}{marker}{word.lower()}\"\n",
    "    return _re_all_caps.sub(_replace_all_caps, t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for inp,exp in [(\"I'M SHOUTING\", f\"{TK_UP} i'm {TK_UP} shouting\"),\n",
    "                (\"I'm speaking normally\", \"I'm speaking normally\"),\n",
    "                (\"I am speaking normally\", \"i am speaking normally\")]:\n",
    "    test_eq(replace_all_caps(inp), exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# A whitespace/start-of-text prefix (group 1), then one capital letter followed by characters that\n",
    "# are neither uppercase nor whitespace (group 2), ending right before whitespace or the end of text\n",
    "_re_maj = re.compile(r'(\\s|^)([A-Z][^A-Z\\s]*)(?=(\\s|$))')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "r\"\"\"\n",
    "Catches any capitalized word\n",
    "(\\s|^)       Capture either a whitespace or the beginning of text\n",
    "([A-Z]       Capture exactly one capitalized letter...\n",
    "[^A-Z\\s]*)   ...followed by anything that's neither uppercase nor whitespace\n",
    "(?=(\\s|$))   Look ahead for a whitespace or end of text\n",
    "\"\"\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def replace_maj(t):\n",
    "    \"Lowercase Capitalized words, inserting `TK_MAJ` before those longer than one character\"\n",
    "    def _replace_maj(m):\n",
    "        lead,word = m.group(1),m.group(2)\n",
    "        # a single capital letter (e.g. \"I\") is lowered without a marker\n",
    "        marker = f'{TK_MAJ} ' if len(word) > 1 else ''\n",
    "        return f\"{lead}{marker}{word.lower()}\"\n",
    "    return _re_maj.sub(_replace_maj, t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(replace_maj(\"Jeremy Howard\"), f'{TK_MAJ} jeremy {TK_MAJ} howard')\n",
    "# a one-letter word such as \"I\" is lowered without a marker\n",
    "expected = \"i don't think there is any maj here\"\n",
    "test_eq(replace_maj(\"I don't think there is any maj here\"), expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def lowercase(t, add_bos=True, add_eos=False):\n",
    "    \"Lowercase and strip `t`, optionally prepending `BOS` and appending `EOS`\"\n",
    "    res = t.lower().strip()\n",
    "    if add_bos: res = f'{BOS} ' + res\n",
    "    if add_eos: res = res + f' {EOS}'\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def replace_space(t):\n",
    "    \"Swap every space inside the token `t` for the unicode '▁' character, so tokens can be joined/split on spaces\"\n",
    "    return '▁'.join(t.split(' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "defaults.text_spec_tok = [UNK, PAD, BOS, EOS, FLD,\n",
    "                          TK_REP, TK_WREP, TK_UP, TK_MAJ]\n",
    "# Applied in this order to each raw text, before tokenization\n",
    "defaults.text_proc_rules = [\n",
    "    fix_html, replace_rep, replace_wrep,       # clean up, mark repetitions\n",
    "    spec_add_spaces, rm_useless_spaces,        # normalize spacing\n",
    "    replace_all_caps, replace_maj, lowercase]  # mark case, then lowercase\n",
    "# Applied to every token after tokenization\n",
    "defaults.text_postproc_rules = [replace_space]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tokenizer is a class that must implement `__call__`. This method receives an iterator of texts and must return a generator with their tokenized versions. Here is the most basic example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class BaseTokenizer():\n",
    "    \"Basic tokenizer that just splits on spaces\"\n",
    "    def __init__(self, split_char=' ', **kwargs):\n",
    "        # extra keyword arguments are accepted and ignored\n",
    "        self.split_char = split_char\n",
    "\n",
    "    def __call__(self, items):\n",
    "        \"Lazily split each text of `items` on `split_char`\"\n",
    "        return (text.split(self.split_char) for text in items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(BaseTokenizer()([\"This is a text\"]), [[\"This\", \"is\", \"a\", \"text\"]])\n",
    "# a custom `split_char` replaces the default space\n",
    "tok = BaseTokenizer('x')\n",
    "test_eq(tok([\"This is a text\"]), [[\"This is a te\", \"t\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SpacyTokenizer():\n",
    "    \"Spacy tokenizer for `lang`\"\n",
    "    def __init__(self, lang='en', special_toks=None, buf_sz=5000):\n",
    "        self.special_toks = ifnone(special_toks, defaults.text_spec_tok)\n",
    "        nlp = spacy.blank(lang)\n",
    "        # register every special token so spaCy keeps it as one single token\n",
    "        for tok in self.special_toks: nlp.tokenizer.add_special_case(tok, [{ORTH: tok}])\n",
    "        self.pipe = nlp.pipe\n",
    "        self.buf_sz = buf_sz\n",
    "\n",
    "    def __call__(self, items):\n",
    "        \"Lazily tokenize each item of `items` (converted to `str`), processing them in batches of `buf_sz`\"\n",
    "        docs = self.pipe(map(str, items), batch_size=self.buf_sz)\n",
    "        return (L(doc).attrgot('text') for doc in docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Default word-level tokenizer, instantiated whenever no `tok` is passed\n",
    "WordTokenizer = SpacyTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tok = SpacyTokenizer()\n",
    "inp = \"This isn't the easiest text.\"\n",
    "exp = [\"This\", \"is\", \"n't\", \"the\", \"easiest\", \"text\", \".\"]\n",
    "test_eq(L(tok([inp,inp])), [exp,exp])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TokenizeWithRules:\n",
    "    \"A wrapper around `tok` which applies `rules`, then tokenizes, then applies `post_rules`\"\n",
    "    def __init__(self, tok, rules=None, post_rules=None):\n",
    "        self.tok = tok\n",
    "        self.rules = L(ifnone(rules, defaults.text_proc_rules))\n",
    "        # all post-rules are fused into one function applied to each token\n",
    "        post_rules = L(ifnone(post_rules, defaults.text_postproc_rules))\n",
    "        self.post_f = compose(*post_rules)\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        \"Apply `rules` to each text of `batch`, tokenize, then apply `post_f` to every token\"\n",
    "        cleaned = maps(*self.rules, batch)\n",
    "        return (L(toks).map(self.post_f) for toks in self.tok(cleaned))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checks = [(TokenizeWithRules(BaseTokenizer(), rules=[replace_all_caps]),\n",
    "           \"THIS isn't a problem\", [TK_UP, 'this', \"isn't\", 'a', 'problem']),\n",
    "          (TokenizeWithRules(SpacyTokenizer()),\n",
    "           \"This isn't a problem\", [BOS, TK_MAJ, 'this', 'is', \"n't\", 'a', 'problem']),\n",
    "          # without pre-rules only the default post-rule (`replace_space`) changes the tokens\n",
    "          (TokenizeWithRules(BaseTokenizer(split_char=\"'\"), rules=[]),\n",
    "           \"This isn't a problem\", ['This▁isn', 't▁a▁problem'])]\n",
    "for f,text,exp in checks: test_eq(f([text]), [exp])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main function that will be called during one of the processes handling tokenization. It will iterate through the `batch` of texts, apply `rules` to them and tokenize them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\"this is a text\", \"this is another text\"]\n",
    "# `rules` can be any callable: here each index is first mapped to its text\n",
    "tok = TokenizeWithRules(BaseTokenizer(), texts.__getitem__)\n",
    "test_eq(tok([0,1]), [t.split(' ') for t in texts])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(TokenizeWithRules)\n",
    "def tokenize1(text, tok, **kwargs):\n",
    "    \"Tokenize the single string `text` with `tok`, through `TokenizeWithRules`\"\n",
    "    tokenizer = TokenizeWithRules(tok=tok, **kwargs)\n",
    "    return first(tokenizer([text]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"This isn't a problem\"\n",
    "test_eq(tokenize1(text, SpacyTokenizer()), [BOS, TK_MAJ, 'this', 'is', \"n't\", 'a', 'problem'])\n",
    "test_eq(tokenize1(text, tok=BaseTokenizer(), rules=[]), text.split(' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def parallel_tokenize(items, tok=None, rules=None, n_workers=defaults.cpus, **kwargs):\n",
    "    \"Calls optional `setup` on `tok` before launching `TokenizeWithRules` using `parallel_gen`\"\n",
    "    if tok is None: tok = WordTokenizer()\n",
    "    # `setup` is optional: only tokenizers that define it are first given the items and rules\n",
    "    if hasattr(tok, 'setup'): tok.setup(items, rules)\n",
    "    # yields `(index, tokens)` pairs, not necessarily in the order of `items`\n",
    "    return parallel_gen(TokenizeWithRules, items, tok=tok, rules=rules, n_workers=n_workers, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that since this uses `parallel_gen` behind the scenes, the generator returned contains tuples of indices and results. There is no guarantee that the results are returned in order, so you should sort by the first item of the tuples (the indices) if you need them ordered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "res = parallel_tokenize(['0 1', '1 2'], rules=[], n_workers=2)\n",
    "# results may arrive out of order: sort the (index, tokens) pairs by index\n",
    "idxs,toks = zip(*sorted(res, key=itemgetter(0)))\n",
    "test_eq(toks, [['0','1'],['1','2']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# `n_workers=0` should give the same tokens as the parallel run above\n",
    "res1 = parallel_tokenize(['0 1', '1 2'], tok=BaseTokenizer(), rules=[], n_workers=0)\n",
    "idxs1,toks1 = zip(*sorted(res1, key=itemgetter(0)))\n",
    "test_eq(toks, toks1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tokenize texts in files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Preprocessing function for texts stored in files. Tokenized texts will be saved in a similar fashion in a directory suffixed with `_tok` in the parent folder of `path` (override with `output_dir`). This directory is the return value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Names of the pickles written into a tokenized output folder\n",
    "fn_counter_pkl = 'counter.pkl'  # `Counter` of all tokens\n",
    "fn_lengths_pkl = 'lengths.pkl'  # number of tokens per file, keyed by path relative to the input folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _tokenize_files(func, files, path, output_dir=None, output_names=None, n_workers=defaults.cpus, rules=None, tok=None,\n",
    "                   encoding='utf8', skip_if_exists=False):\n",
    "    \"Tokenize text `files` in parallel using `n_workers`\"\n",
    "    # `func(i, output_dir)` returns where the tokens of `files[i]` are written. `output_names` is\n",
    "    # accepted for signature compatibility only: callers encode custom names in `func`.\n",
    "    if tok is None: tok = WordTokenizer()\n",
    "    path = Path(path)\n",
    "    output_dir = Path(ifnone(output_dir, path.parent/f'{path.name}_tok'))\n",
    "    if skip_if_exists and output_dir.exists(): return output_dir\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    # reading the file is the first rule, so texts are loaded inside the tokenization pipeline\n",
    "    rules = partial(Path.read_text, encoding=encoding) + L(ifnone(rules, defaults.text_proc_rules.copy()))\n",
    "\n",
    "    lengths,counter = {},Counter()\n",
    "    for i,toks in parallel_tokenize(files, tok, rules, n_workers=n_workers):\n",
    "        out = func(i,output_dir)\n",
    "        out.mk_write(' '.join(toks), encoding=encoding)\n",
    "        lengths[str(files[i].relative_to(path))] = len(toks)\n",
    "        counter.update(toks)\n",
    "\n",
    "    save_pickle(output_dir/fn_lengths_pkl, lengths)\n",
    "    save_pickle(output_dir/fn_counter_pkl, counter)\n",
    "    return output_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(_tokenize_files)\n",
    "def tokenize_folder(path, extensions=None, folders=None, output_dir=None, skip_if_exists=True, **kwargs):\n",
    "    \"Tokenize text files in `path` in parallel using `n_workers`\"\n",
    "    path,extensions = Path(path),ifnone(extensions, ['.txt'])\n",
    "    files = get_files(path, extensions=extensions, recurse=True, folders=folders)\n",
    "    # mirror the layout of `path` inside `output_dir`\n",
    "    def _f(i,output_dir): return output_dir/files[i].relative_to(path)\n",
    "    # forward the caller's `output_dir`; `None` lets `_tokenize_files` use `{path.name}_tok` next to `path`\n",
    "    return _tokenize_files(_f, files, path, output_dir=output_dir, skip_if_exists=skip_if_exists, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result will be in `output_dir` (defaults to a folder in the same parent directory as `path`, with `_tok` added to `path.name`) with the same structure as in `path`. Tokenized texts for a given file will be in the file having the same name in `output_dir`. Additionally, the number of tokens of each file is stored in `output_dir/lengths.pkl` and the count of all words is stored in `output_dir/counter.pkl`.\n",
    "\n",
    "`extensions` will default to `['.txt']` and all text files in `path` are treated unless you specify a list of folders in `folders`. `rules` (that defaults to `defaults.text_proc_rules`) are applied to each text before going in the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(_tokenize_files)\n",
    "def tokenize_files(files, path, output_dir, output_names=None, **kwargs):\n",
    "    \"Tokenize text `files` in parallel using `n_workers`\"\n",
    "    # `output_names` are relative to `output_dir`; by default each file keeps its location relative to `path`.\n",
    "    # (Prefixing them with `output_dir` here would duplicate it whenever `output_dir` is a relative path.)\n",
    "    if output_names is None: output_names = L(f.relative_to(path) for f in files)\n",
    "    def _f(i,output_dir): return output_dir/output_names[i]\n",
    "    return _tokenize_files(_f, files, path, output_dir=output_dir, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tokenize texts in a dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _join_texts(df, mark_fields=False):\n",
    "    \"Join the text of every column of `df`, row by row, marking each field with `FLD` and its number if `mark_fields=True`\"\n",
    "    # returns a numpy array with one joined string per row; fields are numbered from 1\n",
    "    text_col = (f'{FLD} 1 ' if mark_fields else '') + df.iloc[:,0].astype(str)\n",
    "    for i in range(1,len(df.columns)):\n",
    "        text_col += (f' {FLD} {i+1} ' if mark_fields else ' ') + df.iloc[:,i].astype(str)\n",
    "    return text_col.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "texts = [f\"This is an example of text {i}\" for i in range(10)]\n",
    "df = pd.DataFrame({'text': texts, 'text1': texts}, columns=['text', 'text1'])\n",
    "col = _join_texts(df, mark_fields=True)\n",
    "# each row joins the two columns, each prefixed by `FLD` and its field number\n",
    "for i,joined in enumerate(col):\n",
    "    test_eq(joined, f'{FLD} 1 {texts[i]} {FLD} 2 {texts[i]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def tokenize_texts(texts, n_workers=defaults.cpus, rules=None, tok=None):\n",
    "    \"Tokenize `texts` in parallel using `n_workers`, returning the tokens in the order of `texts`\"\n",
    "    rules = L(ifnone(rules, defaults.text_proc_rules.copy()))\n",
    "    # `parallel_tokenize` yields (index, tokens) pairs in arbitrary order\n",
    "    pairs = L(parallel_tokenize(texts, tok=tok, rules=rules, n_workers=n_workers)).sorted()\n",
    "    return pairs.itemgot(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def tokenize_df(df, text_cols, n_workers=defaults.cpus, rules=None, mark_fields=None,\n",
    "                tok=None, tok_text_col=\"text\"):\n",
    "    \"Tokenize texts in `df[text_cols]` in parallel using `n_workers`; return a copy of the other columns with the tokens in `tok_text_col`, plus a token `Counter`\"\n",
    "    # integer entries of `text_cols` are positions in `df.columns`\n",
    "    text_cols = [c if not isinstance(c, int) else df.columns[c] for c in L(text_cols)]\n",
    "    # field markers default to on only when several text columns are joined\n",
    "    if mark_fields is None: mark_fields = len(text_cols) > 1\n",
    "    rules = L(ifnone(rules, defaults.text_proc_rules.copy()))\n",
    "    texts = _join_texts(df[text_cols], mark_fields=mark_fields)\n",
    "    # `parallel_tokenize` yields (index, tokens) pairs in arbitrary order: sort to restore row order\n",
    "    pairs = L(parallel_tokenize(texts, tok, rules, n_workers=n_workers)).sorted()\n",
    "    outputs = pairs.itemgot(1)\n",
    "\n",
    "    keep_cols = df.columns[~df.columns.isin(text_cols)]\n",
    "    res = df[keep_cols].copy()\n",
    "    res[tok_text_col] = outputs\n",
    "    res[f'{tok_text_col}_length'] = [len(toks) for toks in outputs]\n",
    "    return res,Counter(outputs.concat())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function returns a new dataframe with the same non-text columns, a column named `tok_text_col` (`text` by default) that contains the tokenized texts and a column named `{tok_text_col}_length` (`text_length` by default) that contains their respective length. It also returns a counter of all seen words to quickly build a vocabulary afterward.\n",
    "\n",
    "`rules` (that defaults to `defaults.text_proc_rules`) are applied to each text before going in the tokenizer. If `mark_fields` isn't specified, it defaults to `False` when there is a single text column, `True` when there are several. In that case, the texts in each of those columns are joined with `FLD` markers followed by the number of the field."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def tokenize_csv(fname, text_cols, outname=None, n_workers=4, rules=None, mark_fields=None,\n",
    "                 tok=None, header='infer', chunksize=50000):\n",
    "    \"Tokenize texts in the `text_cols` of the csv `fname` in parallel using `n_workers`\"\n",
    "    fname = Path(fname)  # also accept `str` paths: `.parent`/`.stem` below need a `Path`\n",
    "    chunks = pd.read_csv(fname, header=header, chunksize=chunksize)\n",
    "    # with `chunksize=None` pandas returns a single DataFrame rather than an iterator of chunks\n",
    "    if isinstance(chunks, pd.DataFrame): chunks = [chunks]\n",
    "    outname = Path(ifnone(outname, fname.parent/f'{fname.stem}_tok.csv'))\n",
    "    cnt = Counter()\n",
    "\n",
    "    for i,dfp in enumerate(chunks):\n",
    "        out,c = tokenize_df(dfp, text_cols, n_workers=n_workers, rules=rules,\n",
    "                            mark_fields=mark_fields, tok=tok)\n",
    "        out.text = out.text.str.join(' ')\n",
    "        # the first chunk (over)writes the file using `header`; later chunks are appended without one\n",
    "        first = i == 0\n",
    "        out.to_csv(outname, header=header if first else None, index=False, mode='w' if first else 'a')\n",
    "        cnt.update(c)\n",
    "\n",
    "    # the token counter is saved next to the csv, with a `.pkl` suffix\n",
    "    save_pickle(outname.with_suffix('.pkl'), cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def load_tokenized_csv(fname):\n",
    "    \"Utility function to quickly load a tokenized csv and the corresponding counter\"\n",
    "    fname = Path(fname)\n",
    "    out = pd.read_csv(fname)\n",
    "    # tokenized columns are those with a matching `{col}_length` column (as written by `tokenize_df`);\n",
    "    # otherwise fall back to the historical layout: every column between the first and the last one\n",
    "    txt_cols = [c for c in out.columns if f'{c}_length' in out.columns]\n",
    "    if not txt_cols: txt_cols = out.columns[1:-1]\n",
    "    for txt_col in txt_cols:\n",
    "        out[txt_col] = tuple(out[txt_col].str.split(' '))\n",
    "    # NOTE: unpickling can execute arbitrary code -- only load counters from trusted files\n",
    "    return out,load_pickle(fname.with_suffix('.pkl'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result will be written in a new csv file in `outname` (defaults to the same as `fname` with the suffix `_tok.csv`) and will have the same header as the original file, the same non-text columns, a `text` and a `text_length` column as described in `tokenize_df`.\n",
    "\n",
    "`rules` (that defaults to `defaults.text_proc_rules`) are applied to each text before going in the tokenizer. If `mark_fields` isn't specified, it defaults to `False` when there is a single text column, `True` when there are several. In that case, the texts in each of those columns are joined with `FLD` markers followed by the number of the field.\n",
    "\n",
    "The csv file is opened with `header` and read in blocks of `chunksize` rows at a time (50,000 by default): each chunk is processed independently and appended to the output file to limit memory usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _prepare_texts(tmp_d):\n",
    "    \"Write 15 sample texts to `tmp_d/tmp/{a,b,c}/text{i}.txt` and to `tmp_d/input.csv`; return `(path, df, csv_fname)`\"\n",
    "    tmp_d = Path(tmp_d)  # accept `str` as well as `Path`\n",
    "    path = tmp_d/'tmp'\n",
    "    path.mkdir()\n",
    "    for d in ['a', 'b', 'c']:\n",
    "        (path/d).mkdir()\n",
    "        for i in range(5): (path/d/f'text{i}.txt').write_text(f\"This is an example of text {d} {i}\", encoding='utf8')\n",
    "\n",
    "    texts = [f\"This is an example of text {d} {i}\" for i in range(5) for d in ['a', 'b', 'c']]\n",
    "    df = pd.DataFrame({'text': texts, 'label': list(range(15))}, columns=['text', 'label'])\n",
    "    csv_fname = tmp_d/'input.csv'\n",
    "    df.to_csv(csv_fname, index=False)\n",
    "    return path,df,csv_fname"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#hide\n",
    "# integration test\n",
    "with tempfile.TemporaryDirectory() as tmp_d:\n",
    "    path,df,csv_fname = _prepare_texts(Path(tmp_d))\n",
    "    #Tokenize as folders\n",
    "    outp = tokenize_folder(path)\n",
    "    test_eq(outp, Path(tmp_d)/'tmp_tok')\n",
    "    for d in ['a', 'b', 'c']:\n",
    "        p = outp/d\n",
    "        for i in range(5):\n",
    "            test_eq((p/f'text{i}.txt').read_text(), ' '.join([\n",
    "                BOS, TK_MAJ, 'this', 'is', 'an', 'example', 'of', 'text', d, str(i) ]))\n",
    "    cnt_a = load_pickle(outp/fn_counter_pkl)\n",
    "    test_eq(cnt_a['this'], 15)\n",
    "    test_eq(cnt_a['a'], 5)\n",
    "    test_eq(cnt_a['0'], 3)\n",
    "\n",
    "    #Tokenize as files\n",
    "    files = get_text_files(path)\n",
    "    tokenize_files(files, path, output_dir=path/'d')\n",
    "    for f in files:\n",
    "        test_eq((path/'d'/f.relative_to(path)).read_text(), ' '.join([\n",
    "                BOS, TK_MAJ, 'this', 'is', 'an', 'example', 'of', 'text', f.parent.name, f.name[4]]))\n",
    "\n",
    "    #Tokenize as individual texts\n",
    "    out = tokenize_texts(df['text'].values)\n",
    "    test_eq(out, [(outp/d/f'text{i}.txt').read_text().split(' ') for i in range(5) for d in ['a', 'b', 'c']])\n",
    "\n",
    "    #Tokenize as a dataframe\n",
    "    out,cnt_b = tokenize_df(df, text_cols='text')\n",
    "    test_eq(list(out.columns), ['label', 'text', 'text_length'])\n",
    "    test_eq(out['label'].values, df['label'].values)\n",
    "    test_eq(list(out['text']), [(outp/d/f'text{i}.txt').read_text().split(' ') for i in range(5) for d in ['a', 'b', 'c']])\n",
    "    test_eq(cnt_a, cnt_b)\n",
    "\n",
    "    #Tokenize as a csv\n",
    "    out_fname = Path(tmp_d)/'output.csv'\n",
    "    tokenize_csv(csv_fname, text_cols='text', outname=out_fname)\n",
    "    test_eq((out,cnt_b), load_tokenized_csv(out_fname))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Tokenizer`-"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Tokenizer(Transform):\n",
    "    \"Provides a consistent `Transform` interface to tokenizers operating on `DataFrame`s and folders\"\n",
    "    input_types = (str, list, L, tuple, Path)\n",
    "    def __init__(self, tok, rules=None, counter=None, lengths=None, mode=None, sep=' '):\n",
    "        if isinstance(tok,type): tok=tok()\n",
    "        store_attr('tok,counter,lengths,mode,sep')\n",
    "        self.rules = defaults.text_proc_rules if rules is None else rules\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(tokenize_df, keep=True)\n",
    "    def from_df(cls, text_cols, tok=None, rules=None, sep=' ', **kwargs):\n",
    "        \"Create a `Tokenizer` that tokenizes the `text_cols` of a `DataFrame` in `setups`\"\n",
    "        if tok is None: tok = WordTokenizer()\n",
    "        res = cls(tok, rules=rules, mode='df')\n",
    "        # `tok` and any extra kwargs are forwarded to `tokenize_df` by `setups`\n",
    "        res.kwargs,res.train_setup = merge({'tok': tok}, kwargs),False\n",
    "        res.text_cols,res.sep = text_cols,sep\n",
    "        return res\n",
    "\n",
    "    @classmethod\n",
    "    @delegates(tokenize_folder, keep=True)\n",
    "    def from_folder(cls, path, tok=None, rules=None, **kwargs):\n",
    "        \"Tokenize the texts in `path` now with `tokenize_folder`, and create a `Tokenizer` that reads the results\"\n",
    "        path = Path(path)\n",
    "        if tok is None: tok = WordTokenizer()\n",
    "        output_dir = tokenize_folder(path, tok=tok, rules=rules, **kwargs)\n",
    "        res = cls(tok, counter=load_pickle(output_dir/fn_counter_pkl),\n",
    "                  lengths=load_pickle(output_dir/fn_lengths_pkl), rules=rules, mode='folder')\n",
    "        res.path,res.output_dir = path,output_dir\n",
    "        return res\n",
    "\n",
    "    def setups(self, dsets):\n",
    "        \"In `df` mode, replace `dsets.items` by its tokenized version and keep the token counter\"\n",
    "        if not self.mode == 'df' or not isinstance(dsets.items, pd.DataFrame): return\n",
    "        dsets.items,count = tokenize_df(dsets.items, self.text_cols, rules=self.rules, **self.kwargs)\n",
    "        if self.counter is None: self.counter = count\n",
    "        return dsets\n",
    "\n",
    "    def encodes(self, o:Path):\n",
    "        # Files under `self.path` were already tokenized by `from_folder`: read the saved tokens.\n",
    "        # Check `o.parents` rather than a string prefix, so a sibling folder whose name merely starts\n",
    "        # with the same characters is not mistaken for `self.path` (which made `relative_to` raise).\n",
    "        if self.mode=='folder' and self.path in o.parents:\n",
    "            tok = self.output_dir/o.relative_to(self.path)\n",
    "            return L(tok.read_text(encoding='UTF-8').split(' '))\n",
    "        else: return self._tokenize1(o.read_text())\n",
    "\n",
    "    def encodes(self, o:str): return self._tokenize1(o)\n",
    "    def _tokenize1(self, o): return first(self.tok([compose(*self.rules)(o)]))\n",
    "\n",
    "    def get_lengths(self, items):\n",
    "        \"Return the number of tokens of each of `items` when it is known, else `None`\"\n",
    "        # NOTE(review): instances built by `from_df` have `lengths=None` and return here, before the df branch -- confirm intended\n",
    "        if self.lengths is None: return None\n",
    "        if self.mode == 'df':\n",
    "            # `tokenize_df` stores the lengths in the `text_length` column\n",
    "            if isinstance(items, pd.DataFrame) and 'text_length' in items.columns: return items['text_length'].values\n",
    "        if self.mode == 'folder':\n",
    "            try:\n",
    "                res = [self.lengths[str(Path(i).relative_to(self.path))] for i in items]\n",
    "                if len(res) == len(items): return res\n",
    "            except (KeyError, ValueError, TypeError): return None  # unknown file, file outside `self.path`, or unsized/invalid `items`\n",
    "\n",
    "    def decodes(self, o): return TitledStr(self.sep.join(o))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(['xxbos', 'xxmaj', 'this', 'is', 'an', 'example', 'of', 'text', 'b', '0'],)\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('xxbos', 'xxmaj', 'this', 'is', 'an', 'example', 'of', 'text', 'c', '3')\n"
     ]
    }
   ],
   "source": [
    "with tempfile.TemporaryDirectory() as tmp_d:\n",
    "    path,df,csv_fname = _prepare_texts(Path(tmp_d))\n",
    "    items = get_text_files(path)\n",
    "    splits = RandomSplitter()(items)\n",
    "    # folder mode: `from_folder` tokenizes every file up front with `tokenize_folder`\n",
    "    dsets = Datasets(items, [Tokenizer.from_folder(path)], splits=splits)\n",
    "    print(dsets.train[0])\n",
    "    \n",
    "    # df mode: the 'text' column is tokenized by `Tokenizer.setups` (via `tokenize_df`);\n",
    "    # `splits` computed on the file list is reused, assuming `df` has one row per text file\n",
    "    dsets = Datasets(df, [Tokenizer.from_df('text')], splits=splits)\n",
    "    print(dsets.train[0][0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw strings in a test set are tokenized like the df-based `dsets` above:\n",
    "# only the capitalized 'This' is preceded by `xxmaj`\n",
    "tst = test_set(dsets, ['This is a test', 'this is another test'])\n",
    "test_eq(tst, [(['xxbos', 'xxmaj', 'this','is','a','test'],), \n",
    "              (['xxbos','this','is','another','test'],)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SentencePiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# ISO 639-1 codes of the 24 official languages of the European Union\n",
    "eu_langs = \"bg cs da de el en es et fi fr ga hr hu it lt lv mt nl pl pt ro sk sl sv\".split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SentencePieceTokenizer():#TODO: pass the special tokens symbol to sp\n",
    "    \"SentencePiece tokenizer for `lang`\"\n",
    "    def __init__(self, lang='en', special_toks=None, sp_model=None, vocab_sz=None, max_vocab_sz=30000,\n",
    "                 model_type='unigram', char_coverage=None, cache_dir='tmp'):\n",
    "        try: from sentencepiece import SentencePieceTrainer,SentencePieceProcessor\n",
    "        except ImportError as e:\n",
    "            raise ImportError('sentencepiece module is missing: run `pip install sentencepiece!=0.1.90,!=0.1.91`') from e\n",
    "        self.sp_model,self.cache_dir = sp_model,Path(cache_dir)\n",
    "        self.vocab_sz,self.max_vocab_sz,self.model_type = vocab_sz,max_vocab_sz,model_type\n",
    "        # languages listed in `eu_langs` get a higher default character coverage\n",
    "        self.char_coverage = ifnone(char_coverage, 0.99999 if lang in eu_langs else 0.9998)\n",
    "        self.special_toks = ifnone(special_toks, defaults.text_spec_tok)\n",
    "        if sp_model is None: self.tok = None\n",
    "        else:\n",
    "            self.tok = SentencePieceProcessor()\n",
    "            self.tok.Load(str(sp_model))\n",
    "        os.makedirs(self.cache_dir, exist_ok=True)\n",
    "\n",
    "    def _get_vocab_sz(self, raw_text_path):\n",
    "        \"Pick a vocab size from the number of distinct whitespace-separated words in `raw_text_path`\"\n",
    "        cnt = Counter()\n",
    "        # `setup` writes this file as UTF-8, so read it back the same way; iterating lazily\n",
    "        # lets a large corpus stop as soon as `max_vocab_sz` is exceeded\n",
    "        with open(raw_text_path, 'r', encoding='utf-8') as f:\n",
    "            for line in f:\n",
    "                cnt.update(line.split())\n",
    "                if len(cnt)//4 > self.max_vocab_sz: return self.max_vocab_sz\n",
    "        res = len(cnt)//4\n",
    "        # round up to a multiple of 8, then enforce a minimum of 29\n",
    "        while res%8 != 0: res+=1\n",
    "        return max(res,29)\n",
    "\n",
    "    def train(self, raw_text_path):\n",
    "        \"Train a sentencepiece model on the text file `raw_text_path` (deleted afterwards) and save it as `cache_dir/spm.model`\"\n",
    "        from sentencepiece import SentencePieceTrainer\n",
    "        vocab_sz = self._get_vocab_sz(raw_text_path) if self.vocab_sz is None else self.vocab_sz\n",
    "        # prefix special tokens with U+2581, the symbol sentencepiece uses to mark a preceding space\n",
    "        spec_tokens = ['\\u2581'+s for s in self.special_toks]\n",
    "        SentencePieceTrainer.Train(\" \".join([\n",
    "            f\"--input={raw_text_path} --vocab_size={vocab_sz} --model_prefix={self.cache_dir/'spm'}\",\n",
    "            f\"--character_coverage={self.char_coverage} --model_type={self.model_type}\",\n",
    "            f\"--unk_id={len(spec_tokens)} --pad_id=-1 --bos_id=-1 --eos_id=-1 --minloglevel=2\",\n",
    "            f\"--user_defined_symbols={','.join(spec_tokens)} --hard_vocab_limit=false\"]))\n",
    "        raw_text_path.unlink()\n",
    "        return self.cache_dir/'spm.model'\n",
    "\n",
    "    def setup(self, items, rules=None):\n",
    "        \"Train a model on `items` (after applying `rules`) unless one is loaded; return `{'sp_model': <model path>}`\"\n",
    "        from sentencepiece import SentencePieceProcessor\n",
    "        if rules is None: rules = []\n",
    "        if self.tok is not None: return {'sp_model': self.sp_model}\n",
    "        raw_text_path = self.cache_dir/'texts.out'\n",
    "        # sentencepiece reads its input as UTF-8: do not depend on the platform default encoding\n",
    "        with open(raw_text_path, 'w', encoding='utf-8') as f:\n",
    "            for t in progress_bar(maps(*rules, items), total=len(items), leave=False):\n",
    "                f.write(f'{t}\\n')\n",
    "        sp_model = self.train(raw_text_path)\n",
    "        # remember the trained model so a later `setup` call reports it instead of `None`\n",
    "        self.sp_model = sp_model\n",
    "        self.tok = SentencePieceProcessor()\n",
    "        self.tok.Load(str(sp_model))\n",
    "        return {'sp_model': sp_model}\n",
    "\n",
    "    def __call__(self, items):\n",
    "        \"Yield the sentencepiece pieces of each text in `items`, training a model first if none is loaded\"\n",
    "        if self.tok is None: self.setup(items)\n",
    "        for t in items: yield self.tok.EncodeAsPieces(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# `SubwordTokenizer` is an alias: it is the very same class as `SentencePieceTokenizer`\n",
    "SubwordTokenizer = SentencePieceTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "texts = [f\"This is an example of text {i}\" for i in range(10)]\n",
    "df = pd.DataFrame({'text': texts, 'label': list(range(10))}, columns=['text', 'label'])\n",
    "# pass an explicit `vocab_sz` rather than letting `_get_vocab_sz` estimate one from this tiny corpus\n",
    "out,cnt = tokenize_df(df, text_cols='text', tok=SentencePieceTokenizer(vocab_sz=34), n_workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['▁xx', 'b', 'o', 's', '▁xx', 'm', 'a', 'j', '▁t', 'h', 'i', 's', '▁', 'i', 's', '▁a', 'n', '▁', 'ex', 'a', 'm', 'p', 'l', 'e', '▁', 'o', 'f', '▁t', 'ex', 't', '▁', 'b', '▁', '2']\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['▁xx', 'b', 'o', 's', '▁xx', 'm', 'a', 'j', '▁t', 'h', 'i', 's', '▁', 'i', 's', '▁a', 'n', '▁', 'ex', 'a', 'm', 'p', 'l', 'e', '▁', 'o', 'f', '▁t', 'ex', 't', '▁a', '▁', '4']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/miniconda3/lib/python3.8/site-packages/numpy/core/_asarray.py:102: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
      "  return array(a, dtype, copy=False, order=order)\n"
     ]
    }
   ],
   "source": [
    "with tempfile.TemporaryDirectory() as tmp_d:\n",
    "    path,df,csv_fname = _prepare_texts(Path(tmp_d))\n",
    "    items = get_text_files(path)\n",
    "    splits = RandomSplitter()(items)\n",
    "    tok = SentencePieceTokenizer(special_toks=[])\n",
    "    dsets = Datasets(items, [Tokenizer.from_folder(path, tok=tok)], splits=splits)\n",
    "    print(dsets.train[0][0])\n",
    "    \n",
    "# `catch_warnings` alone only saves and restores the warning filters; `simplefilter('ignore')` is what\n",
    "# actually silences warnings such as the numpy `VisibleDeprecationWarning` seen in the recorded output\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    dsets = Datasets(df, [Tokenizer.from_df('text', tok=tok)], splits=splits)\n",
    "    print(dsets.train[0][0].text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.image_sequence.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "# convert the project's notebooks into library modules (one `Converted ...` line per notebook)\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
